[PKUWC2018]Minimax

2020-01-15
PKUWC

题意

给出一棵二叉树,即每个节点最多两个子节点

对于非叶子节点$i$,有$p_i$的概率选择子节点中权值较小的,有$1-p_i$的概率选择较大的

对于叶子节点,权值给出,且保证所有叶子节点的权值不相同

问最后根节点的权值情况

题解

先把权值离散化,对于节点cur,权值为i的概率为

考虑优化向上合并的过程,刚好可以用线段树合并,合并时顺便维护前缀和、后缀和即可

调试记录

线段树合并的时候先合并了左儿子,权值改变了,更新右儿子的时候出问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <cstdio>
#include <algorithm>
#define LS a[a[cur].ls]
#define RS a[a[cur].rs]
using namespace std;
const int maxn = 3e5 + 5;
const int mo = 998244353;
int pow(int x, int t){
int res = 1; x %= mo;
while (t > 0){
if (t & 1) res = 1ll * res * x % mo;
x = 1ll * x * x % mo;
t >>= 1;
}
return res;
}
const int inv = pow(10000, mo - 2);
int link[maxn], cnt = 0;
struct T{
struct A{
int ls, rs, v, tg;
A(){tg = 1;}
}a[maxn * 40];
int rt[maxn], tot = 0;
void pushdown(int cur){
if (a[cur].tg == 1) return;
LS.v = 1ll * LS.v * a[cur].tg % mo;
RS.v = 1ll * RS.v * a[cur].tg % mo;
LS.tg = 1ll * LS.tg * a[cur].tg % mo;
RS.tg = 1ll * RS.tg * a[cur].tg % mo;
a[cur].tg = 1;
}
int upd(int cur, int l, int r, int p, int k){
if (!cur) cur = ++tot;
if (l == r){
a[cur].v = k;
return cur;
}
int mid = l + r >> 1; pushdown(cur);
if (p <= mid) a[cur].ls = upd(a[cur].ls, l, mid, p, k);
else a[cur].rs = upd(a[cur].rs, mid + 1, r, p, k);
a[cur].v = (LS.v + RS.v) % mo;
return cur;
}
int Merge(int cur, int v, int l, int r, int pfu, int sfu, int pfv, int sfv, int P){
if (!cur && !v) return 0;
pushdown(cur); pushdown(v);
if (!cur){
a[v].v = 1ll * a[v].v * (1ll * P * pfu % mo + 1ll * (1 + mo - P) * sfu % mo) % mo;
a[v].tg = 1ll * a[v].tg * (1ll * P * pfu % mo + 1ll * (1 + mo - P) * sfu % mo) % mo;
return v;
}
if (!v){
a[cur].v = 1ll * a[cur].v * (1ll * P * pfv % mo + 1ll * (1 + mo - P) * sfv % mo) % mo;
a[cur].tg = 1ll * a[cur].tg * (1ll * P * pfv % mo + 1ll * (1 + mo - P) * sfv % mo) % mo;
return cur;
}
int mid = l + r >> 1, t1 = LS.v, t2 = RS.v, t3 = a[a[v].ls].v, t4 = a[a[v].rs].v;
a[cur].ls = Merge(a[cur].ls, a[v].ls, l, mid, pfu, (sfu + t2) % mo, pfv, (sfv + t4) % mo, P);
a[cur].rs = Merge(a[cur].rs, a[v].rs, mid + 1, r, (pfu + t1) % mo, sfu, (pfv + t3) % mo, sfv, P);
a[cur].v = (LS.v + RS.v) % mo;
return cur;
}
int ans = 0;
void calc(int cur, int l, int r){
if (l == r){
ans = (ans + 1ll * l * link[l] % mo * a[cur].v % mo * a[cur].v % mo) % mo;
// printf("%d %d %d\n", l, link[l], a[cur].v);
return;
}
pushdown(cur);
int mid = l + r >> 1;
calc(a[cur].ls, l, mid);
calc(a[cur].rs, mid + 1, r);
}
}t;
struct E{
int to, nxt;
}e[maxn << 1];
int head[maxn], tot = 0;
void addedge(int u, int v){
e[++tot].to = v, e[tot].nxt = head[u];
head[u] = tot;
}
int p[maxn];
int v[maxn], tmp[maxn];
void dfs(int cur){
if (head[cur] == 0) t.rt[cur] = t.upd(t.rt[cur], 1, cnt, v[cur], 1);
for (int i = head[cur]; i; i = e[i].nxt){
dfs(e[i].to);
if (t.rt[cur] != 0) t.rt[cur] = t.Merge(t.rt[cur], t.rt[e[i].to], 1, cnt, 0, 0, 0, 0, 1ll * p[cur] * inv % mo);
else t.rt[cur] = t.rt[e[i].to];
}
} int n;
signed main(){
// freopen("1.in", "r", stdin);
scanf("%d", &n);
for (int fa, i = 1; i <= n; i++){
scanf("%d", &fa); addedge(fa, i);
}
for (int i = 1; i <= n; i++){
if (head[i] == 0) scanf("%d", v + i), tmp[++cnt] = v[i];
else scanf("%d", p + i);
} sort(tmp + 1, tmp + cnt + 1);
for (int i = 1; i <= n; i++)
if (head[i] == 0){
int p = lower_bound(tmp + 1, tmp + cnt + 1, v[i]) - tmp;
link[p] = v[i];
v[i] = p;
}
dfs(1);
t.calc(t.rt[1], 1, cnt);
printf("%d\n", t.ans);
return 0;
}